# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: MIT-0
import sys
import numpy as np
from numpy import random as rand
import pandas as pd
from multiprocessing import Pool
import functools, multiprocessing
from datetime import datetime
import math
from collections import OrderedDict
import collections
import time
from scipy.stats import multivariate_normal
from scipy.stats import norm
from experimental_design import get_XYdesign, get_oracle
from VarianceSimsTools_BAI import two_spaces        



# Largest number of pulls simulated in one batch. Rounds asking for more than
# this are split into several batches so the per-batch index/reward arrays
# stay bounded in size.
_MAX_BATCH_PULLS = 10**6


def _simulate_batch(X, theta, sigmas, estimated_sigmas, lambda_vec, budget):
    """Simulate one batch of pulls and return its weighted least-squares sums.

    Each arm ``i`` is pulled ``ceil(budget * lambda_i / sum(lambda))`` times,
    so the number of pulls actually drawn can exceed ``budget``.

    Returns a tuple ``(precision, xty, drawn)`` where
    ``precision = sum_i count_i * x_i x_i^T / estimated_sigmas[i]``,
    ``xty = sum_pulls x * reward / estimated_sigmas[arm]`` and ``drawn`` is
    the number of pulls simulated.
    """
    n_arms = X.shape[0]
    counts = np.ceil(budget * (lambda_vec / np.sum(lambda_vec))).astype(int)
    # A negative weight yields no pulls, as with list repetition ``[i] * k``.
    counts = np.maximum(counts, 0)

    # Pull order is arm 0 repeated, then arm 1, ... so the noise draws line
    # up with the arms in a fixed order.
    arm_choice_index = np.repeat(np.arange(n_arms), counts)
    drawn = len(arm_choice_index)

    # NOTE(review): ``sigmas`` holds x^T Sigma x and is used here as the
    # *scale* (standard deviation) of the noise, while the estimator below
    # weights by 1/estimated_sigmas as if it were a variance -- confirm the
    # intended convention.
    noise = np.random.normal(0, sigmas[arm_choice_index], drawn)
    reward = (X @ theta)[arm_choice_index] + noise

    # Aggregating per arm avoids materialising one row of X per pull.
    precision = (X.T * (counts / estimated_sigmas)) @ X
    reward_per_arm = np.bincount(arm_choice_index, weights=reward, minlength=n_arms)
    xty = X.T @ (reward_per_arm / estimated_sigmas)
    return precision, xty, drawn


def AdaptiveXY(theta, X, Z, Sigma, heteroskedastic=False, oracle=False, delta = 0.05, log=False, seed=0, MVT = False):
    """Best-arm identification master function (H-RAGE, RAGE and oracle allocations).

    Runs elimination rounds ``ell = 1, 2, ...`` until a single candidate row
    of ``Z`` remains. Each round picks an allocation over the rows of ``X``,
    simulates noisy linear rewards, fits a weighted least-squares estimate of
    ``theta`` and then shrinks the candidate set.

    Args:
        theta: True parameter vector used to simulate rewards (used as a 1-D
            array of length ``d``).
        X: ``(n, d)`` array of measurement vectors that can be pulled.
        Z: ``(z_n, d)`` array of candidate arms; the best one maximises
            ``Z @ theta``.
        Sigma: ``(d, d)`` matrix; arm ``x`` gets noise parameter
            ``x^T Sigma x``.
        heteroskedastic: If False every arm uses the largest noise parameter.
            If True (and not ``oracle``) the noise parameters are estimated
            from an initial exploration phase run by ``two_spaces``.
        oracle: If True, sample from the allocation returned by
            ``get_oracle`` and stop once the true best arm beats every other
            arm under the current estimate.
        delta: Failure probability used in the sample-size formulas.
        log: If True, print a progress line after every round.
        seed: Unused. NOTE(review): the generator is re-seeded from OS
            entropy via ``np.random.seed()``; confirm this is intended.
        MVT: Unused.

    Returns:
        Tuple ``(arm_index, oracle_bound, t)``: the index into ``Z`` of the
        remaining candidate, ``2 * oracle_value * log(n / delta)``, and the
        total number of samples counted (including the initial exploration
        budget in the heteroskedastic, non-oracle mode).
    """
    np.random.seed()
    now = datetime.now()
    print("Current Time =", now.strftime("%H:%M:%S"))

    #setup
    n, d = X.shape
    z_n = Z.shape[0]
    iterations = 1000 #iterations for the Frank-Wolfe algorithm

    #compute the outer products ahead of time
    outers = np.matmul(X[:, :, np.newaxis], X[:, np.newaxis, :])

    #the truth
    opt_arm_index = np.argmax(Z @ theta)
    # Candidates are rows of Z, so the set is sized by z_n rather than n.
    potential_opt_set = list(range(z_n))
    t = 0

    #initialize lambda vector and the per-arm noise parameters
    lambda_vec = np.array([1/n]*n)
    sigmas = np.array([x @ Sigma @ x for x in X])
    sigma_max = np.max(sigmas)
    sigma_min = np.min(sigmas)
    initial_t = 0

    if not heteroskedastic:
        # Homoskedastic mode: every arm uses the worst-case noise parameter.
        sigmas = sigma_max*np.ones(n)
        estimated_sigmas = sigmas.copy()
    if heteroskedastic and not oracle:
        kappa = sigma_max/sigma_min
        c = np.sqrt(128*sigma_max/(np.log(n)+1))+np.sqrt(2/d)+np.sqrt(10/(np.log(n+1)+1+np.log(12/delta)))
        #matches Equation (16) with epsilon = 1/60
        t = int(4*c**2*d**2*np.log(12/delta)*kappa**2)
        print(t)
        # NOTE(review): the value above is only printed; the exploration
        # budget actually used is the one computed below.
        c = 256*(11/10)
        t = int(4*c*d**2*np.log(12/delta)*kappa**2)
        print(t)
        initial_t = t
        _, Sigma_hat = two_spaces(X, Sigma, theta, 0, 1000, t)
        estimated_sigmas = np.array([x @ Sigma_hat @ x for x in X])
    if heteroskedastic and oracle:
        estimated_sigmas = sigmas

    #find the oracle lambda distribution
    _, oracle_value, oracle_lambda = get_oracle(outers, X, Z, 10000, opt_arm_index, theta, 0, lambda_vec.copy(), sigmas)
    print(oracle_lambda)

    ell = 1
    while len(potential_opt_set) > 1:

        V = Z[potential_opt_set] #current candidate arms
        eps_ell = 2**(-ell)
        if oracle:
            potential_opt_set = list(range(z_n))
            lambda_vec = oracle_lambda
            #max here?
            sample_size = int(np.ceil(2*ell*oracle_value*np.log(4*ell**2*z_n/delta)))
        else:
            _, design_value, lambda_vec = get_XYdesign(outers, V.copy(), X.copy(), iterations, 0.000001, lambda_vec.copy(), estimated_sigmas.copy())
            # The heteroskedastic variant uses a larger constant (3 vs 2).
            round_constant = 3 if heteroskedastic else 2
            sample_size = int(np.ceil(round_constant*eps_ell**(-2)*design_value*np.log(8*ell**2*z_n**2/delta)))

        print("sample_size")
        print(sample_size)

        t_before = t

        if sample_size > _MAX_BATCH_PULLS:
            # Split the round into batches and accumulate their statistics.
            theta_hat_prec = np.zeros((d, d))
            XtY = np.zeros(d)
            sample_size_left = sample_size
            rounds = int(np.ceil(sample_size/_MAX_BATCH_PULLS))

            for _ in range(rounds):
                round_time = time.perf_counter()

                if sample_size_left > _MAX_BATCH_PULLS:
                    round_sample_size = _MAX_BATCH_PULLS
                    sample_size_left -= round_sample_size
                else:
                    round_sample_size = sample_size_left

                print(round_sample_size)
                print(sample_size_left)

                batch_prec, batch_xty, drawn = _simulate_batch(
                    X, theta, sigmas, estimated_sigmas, lambda_vec, round_sample_size)
                theta_hat_prec += batch_prec
                XtY += batch_xty
                t += drawn

                print("round_time")
                print(time.perf_counter() - round_time)
        else:
            theta_hat_prec, XtY, drawn = _simulate_batch(
                X, theta, sigmas, estimated_sigmas, lambda_vec, sample_size)
            t += drawn

        print("t_after")
        print(t-t_before)

        # Weighted least-squares estimate from this round's samples only.
        theta_hat_cov = np.linalg.pinv(theta_hat_prec)
        theta_hat = theta_hat_cov @ XtY
        empirical_best_value = np.max(V@theta_hat)
        estimated_gaps = empirical_best_value-V@theta_hat

        if oracle:
            # Stop once the true best arm strictly beats every other row of Z
            # under the current estimate.
            margins = (Z[opt_arm_index] - Z) @ theta_hat
            others = np.arange(z_n) != opt_arm_index
            if not np.any(margins[others] <= 0):
                potential_opt_set = [opt_arm_index]
        else:
            # estimated_gaps is aligned with potential_opt_set (both follow
            # the rows of V), so these are positions within the candidate
            # set, which is exactly what np.delete expects.
            eliminated_arms = np.where(estimated_gaps > eps_ell)[0]
            potential_opt_set = np.delete(potential_opt_set, eliminated_arms)

        if log:
            # Stamp each line with the current time, not the start time.
            stamp = datetime.now().strftime("%H:%M:%S")
            if oracle:
                print(stamp, 'sample size=', t)
            else:
                print(stamp, 'sample size=', t - initial_t, 'round ell',
                      ell, 'gaps eliminated', eps_ell*2, 'potentially optimal arms',  potential_opt_set,
                      'estimated gaps', estimated_gaps)

        #Increment
        ell+=1
    return potential_opt_set[0], 2*oracle_value*np.log(n/delta), t


                              
